
import torch
import numpy as np    
from torch import nn
from einops.layers.torch import Rearrange, Reduce


class Permute(nn.Module):
    """
        Perform Random Permutation

    Each input is flattened to ``(-1, dim_token * dim)``, its columns are
    shuffled by one fixed permutation, and the result is reshaped back to a
    3-D tensor.  The permutation is drawn once, at construction time, from
    NumPy's global RNG (``np.random.permutation``) and reused on every call.

    The permutation is registered as a buffer named ``index_pre`` so that it
    is saved in / restored from ``state_dict`` and follows the module across
    ``.to(device)`` calls.  Without this, a model reloaded from a checkpoint
    would shuffle with a freshly drawn permutation.

    Args:
        dim (int): C, size of the last axis of the output when ``flip`` is
            False
        dim_token (int): S, size of the middle axis of the output when
            ``flip`` is False
        flip (bool) : swap axis after permutation or not, i.e. reshape the
            output to ``(-1, dim, dim_token)`` instead of
            ``(-1, dim_token, dim)``
    """
    def __init__(self, dim, dim_token, flip=False):
        super().__init__()
        self.dim_token = dim_token
        self.dim = dim
        self.flip = flip
        # Keep drawing from NumPy's RNG so existing seeding keeps producing
        # the same permutation; store it as an int64 tensor for indexing.
        permutation = np.random.permutation(dim * dim_token)
        self.register_buffer(
            "index_pre", torch.as_tensor(permutation, dtype=torch.long)
        )

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Checkpoints written before ``index_pre`` became a buffer do not
        # contain it; fall back to the current permutation so that strict
        # loading of those checkpoints keeps working.
        key = prefix + "index_pre"
        if key not in state_dict:
            state_dict[key] = self.index_pre
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, x):
        # Flatten everything but the leading (batch) axis, then shuffle.
        flat = x.reshape(-1, self.dim_token * self.dim)
        shuffled = flat[:, self.index_pre]
        if self.flip:
            return shuffled.reshape(-1, self.dim, self.dim_token)
        return shuffled.reshape(-1, self.dim_token, self.dim)


def Conjugation(dim:int, dim_token:int, fn, J:str=None, o_dim:int=-1, o_dim_token:int=-1):
    """
        Wrap ``fn`` between a pair of layers selected by ``J``

    Args:
        dim (int): ``dim`` passed to the ``Permute`` layer placed before ``fn``
        dim_token (int): ``dim_token`` passed to the ``Permute`` layer placed
            before ``fn``
        fn: the module to wrap
        J (str): which wrapping to apply:
            - ``None``: return ``fn`` itself, unwrapped
            - ``"Transpose"`` / ``"transpose"`` / ``"T"``: rearrange
              ``b n d -> b d n`` before ``fn`` and ``b d n -> b n d`` after
            - ``"Permute"`` / ``"Permutation"`` / ``"P"``:
              ``Permute(dim, dim_token)`` before ``fn`` and
              ``Permute(o_dim, o_dim_token)`` after
            - ``"PermuteTranspose"`` / ``"PT"``: same with ``flip=True``,
              using ``Permute(o_dim_token, o_dim, flip=True)`` after ``fn``
        o_dim (int): ``dim`` on the output side of ``fn``; a negative value
            (the default) means "same as ``dim``"
        o_dim_token (int): ``dim_token`` on the output side of ``fn``; a
            negative value (the default) means "same as ``dim_token``"

    Returns:
        ``fn`` when ``J`` is None, otherwise an ``nn.Sequential`` of
        (pre-layer, ``fn``, post-layer)

    Raises:
        ValueError: if ``J`` is not one of the names listed above
    """
    # Negative output sizes are sentinels for "unchanged by fn".
    if o_dim < 0:
        o_dim = dim
    if o_dim_token < 0:
        o_dim_token = dim_token

    if J is None:
        return fn

    if J in ("Transpose", "transpose", "T"):
        return nn.Sequential(
            Rearrange("b n d -> b d n"),
            fn,
            Rearrange("b d n -> b n d"),
        )

    if J in ("Permute", "Permutation", "P"):
        return nn.Sequential(
            Permute(dim, dim_token),
            fn,
            Permute(o_dim, o_dim_token),
        )

    if J in ("PermuteTranspose", "PT"):
        # With flip=True, Permute reshapes to (-1, dim, dim_token), so the
        # output-side arguments are passed swapped to end up with
        # (-1, o_dim_token, o_dim).
        return nn.Sequential(
            Permute(dim, dim_token, flip=True),
            fn,
            Permute(o_dim_token, o_dim, flip=True),
        )

    raise ValueError(
        f"Unknown conjugation J={J!r}; expected None, "
        "'Transpose'/'transpose'/'T', 'Permute'/'Permutation'/'P' "
        "or 'PermuteTranspose'/'PT'"
    )



